{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "TheGameOfLifeWithGNNs.ipynb",
      "private_outputs": true,
      "provenance": [],
      "collapsed_sections": [
        "yi_QHiT8917x",
        "04bXouXy8fqE",
        "IFCHS9LG8cBn",
        "99-apbKL9GV0",
        "iJycUA5I8J0t"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yi_QHiT8917x"
      },
      "source": [
        "\n",
        "#### Copyright 2021 Google LLC. Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "otfMTZ6X9-ab",
        "cellView": "form"
      },
      "source": [
        "#@title License\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "04bXouXy8fqE"
      },
      "source": [
        "# The Game of Life and Graph Neural Networks\n",
        "\n",
        "Conway's Game of Life is a game on an infinite grid with cells. It follows the following rules (from [Wikipedia](https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life)):\n",
        "* Any live cell with two or three live neighbours survives.\n",
        "* Any dead cell with three live neighbours becomes a live cell.\n",
        "* All other live cells die in the next generation. Similarly, all other dead cells stay dead.\n",
        "\n",
        "For the purposes of this tutorial, we will look at the Game of Life on finite-sized grids.\n",
        "\n",
        "We use a GCN model to predict future grid states, given the current grid state!\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IFCHS9LG8cBn"
      },
      "source": [
        "# Installation and Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-85CDV3WyMJl"
      },
      "source": [
        "# Install Sonnet.\n",
        "!pip install dm-sonnet==2.0.0\n",
        "\n",
        "# Remove all TensorBoard packages, and install TensorBoard again.\n",
        "!pip list --format=freeze | grep tensorboard | xargs pip uninstall -y\n",
        "!pip install -q tensorboard\n",
        "%load_ext tensorboard"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9lRtZTQHDeRy"
      },
      "source": [
        "from enum import IntEnum\n",
        "from itertools import product\n",
        "from typing import Iterable, Tuple, List, Callable\n",
        "\n",
        "from IPython.display import HTML\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import networkx as nx\n",
        "import tensorflow as tf\n",
        "import sonnet as snt\n",
        "import seaborn as sns\n",
        "import scipy\n",
        "import random\n",
        "import tqdm\n",
        "import os\n",
        "import datetime as dt\n",
        "from google.colab import drive"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "99-apbKL9GV0"
      },
      "source": [
        "# Experimental Parameters"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "X-mgF8pUFQdt"
      },
      "source": [
        "data_generation_seed = 7\n",
        "model_seed = 2\n",
        "num_update_steps = 1\n",
        "overcompleteness = m = 2\n",
        "batch_size = 50\n",
        "num_epochs = 1000\n",
        "early_stopping_lim = 300\n",
        "grid_dims = (8, 8)\n",
        "grid_size = np.prod(grid_dims)\n",
        "num_grids_per_num_alive_cells = 5\n",
        "num_grids = num_grids_per_num_alive_cells * (grid_size//2)\n",
        "model_type = 'gcn-no-norm'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iJycUA5I8J0t"
      },
      "source": [
        "# Directories for Output"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fy5VE0rK0DBG"
      },
      "source": [
        "# Save to Google Drive.\n",
        "GDRIVE_DIR = '/content/gdrive'\n",
        "MYDRIVE_DIR = os.path.join(GDRIVE_DIR, 'MyDrive')\n",
        "MAIN_DIR = os.path.join(MYDRIVE_DIR, 'game-of-life')\n",
        "\n",
        "LOG_DIR = os.path.join(MAIN_DIR, 'logs')\n",
        "CHECKPOINT_DIR = os.path.join(MAIN_DIR, 'checkpoints')\n",
        "SAVED_DIR = os.path.join(MAIN_DIR, 'saved')\n",
        "\n",
        "drive.mount(GDRIVE_DIR)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5l9i-13r_uVl"
      },
      "source": [
        "# Uncomment to clear logs.\n",
        "!rm -rf {LOG_DIR}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oAcEabM_8YPF"
      },
      "source": [
        "# GNN Layer Definitions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c9JtYyArMTPd"
      },
      "source": [
        "* For the GCN, normalization refers to a scaling factor applied to the neighbours embeddings:\n",
        "    * 'degree-i': The GCN variant with degree-normalization:\n",
        "$$\n",
        "    \\text{out}_v = W \\cdot \\sum\\limits_{u \\in \\mathcal{N}(v)}\\frac{h_u}{|\\mathcal{N}(v)|} + B \\cdot h_v + C\n",
        "$$\n",
        "    * 'degree-ij': Original GCN variant from Kipf, et al:\n",
        "$$\n",
        "    \\text{out}_v = W \\cdot \\sum\\limits_{u \\in \\mathcal{N}(v)}\\frac{h_u}{\\sqrt{|\\mathcal{N}(u)||\\mathcal{N}(v)|}} + B \\cdot h_v + C\n",
        "$$\n",
        "    * 'none': The GCN variant with no normalization: \n",
        "$$\n",
        "    \\text{out}_v = W \\cdot \\sum\\limits_{u \\in \\mathcal{N}(v)} h_u + B \\cdot h_v + C\n",
        "$$\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wif0cm6cKDuC"
      },
      "source": [
        "class GCNLayer(snt.Module):\n",
        "  def __init__(self, input_dims, output_dims, normalization='none', name=None):\n",
        "    super(GCNLayer, self).__init__(name=name)\n",
        "    self.w = tf.Variable(tf.random.normal((input_dims, output_dims), dtype=tf.float32), name='w')\n",
        "    self.b = tf.Variable(tf.random.normal((input_dims, output_dims), dtype=tf.float32), name='b')\n",
        "    self.c = tf.Variable(tf.random.normal((output_dims,), dtype=tf.float32), name='c')\n",
        "    self.normalization = normalization\n",
        "\n",
        "    if self.normalization == 'degree-i':\n",
        "      degree_matrix_inv = np.diag(1/np.sum(adjacency_matrix, axis=1))\n",
        "      adjacency = np.matmul(degree_matrix_inv, adjacency_matrix)\n",
        "    elif self.normalization == 'degree-ij':\n",
        "      degree_matrix_inv_sqrt = np.sqrt(np.diag(1/np.sum(adjacency_matrix, axis=1)))\n",
        "      adjacency = np.matmul(np.matmul(degree_matrix_inv_sqrt, adjacency_matrix), degree_matrix_inv_sqrt)\n",
        "    elif self.normalization == 'none':\n",
        "      adjacency = adjacency_matrix\n",
        "    else:\n",
        "      raise ValueError()\n",
        "\n",
        "    self.adjacency = tf.cast(adjacency, tf.float32)\n",
        "\n",
        "  def __call__(self, features):\n",
        "    features = tf.cast(features, tf.float32)\n",
        "    return tf.matmul(tf.matmul(self.adjacency, features), self.w) + tf.matmul(features, self.b) + self.c"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Axc7mTAB8VP3"
      },
      "source": [
        "# Game of Life Logic"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "t45Na6WkDszq"
      },
      "source": [
        "class CellState(IntEnum):\n",
        "  ALIVE = 1\n",
        "  DEAD = 0"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "M6VnMFADDGYY"
      },
      "source": [
        "class GameOfLifeGrid:\n",
        "  def __init__(self,\n",
        "               grid_dims: Tuple[int, int]):\n",
        "    if len(grid_dims) != 2:\n",
        "      raise ValueError('Grid must be a 2-tuple of (height, width).')\n",
        "\n",
        "    self.grid = np.zeros(grid_dims, dtype=int)\n",
        "\n",
        "  @staticmethod\n",
        "  def _set_grid_state(grid: np.ndarray,\n",
        "                      state: CellState,\n",
        "                      positions: Iterable[Tuple[int, int]]):\n",
        "\n",
        "    positions = np.array(positions)\n",
        "    if len(positions.shape) == 1:\n",
        "      positions = np.array([positions])\n",
        "\n",
        "    for position in positions:\n",
        "      grid[position[0], position[1]] = state\n",
        "\n",
        "\n",
        "  @staticmethod\n",
        "  def _set_as_alive(grid: np.ndarray,\n",
        "                    positions: Iterable[Tuple[int, int]]):\n",
        "    GameOfLifeGrid._set_grid_state(grid, CellState.ALIVE, positions)\n",
        "\n",
        "\n",
        "  @staticmethod\n",
        "  def _set_as_dead(grid: np.ndarray,\n",
        "                   positions: Iterable[Tuple[int, int]]):\n",
        "    GameOfLifeGrid._set_grid_state(grid, CellState.DEAD, positions)\n",
        "\n",
        "\n",
        "  def set_as_alive(self,\n",
        "                   positions: Iterable[Tuple[int, int]]):\n",
        "    GameOfLifeGrid._set_as_alive(self.grid, positions)\n",
        "\n",
        "\n",
        "  def set_as_dead(self,\n",
        "                  positions: Iterable[Tuple[int, int]]):\n",
        "    GameOfLifeGrid._set_as_dead(self.grid, positions)\n",
        "\n",
        "\n",
        "  def initialize_randomly(self,\n",
        "                          num_alive_cells: int):\n",
        "    self.grid = np.zeros(self.grid.shape, dtype=int)\n",
        "    all_positions = list(iter(self))\n",
        "    np.random.shuffle(all_positions)\n",
        "    alive_positions = all_positions[:num_alive_cells]\n",
        "    self.set_as_alive(alive_positions)\n",
        "\n",
        "\n",
        "  def __iter__(self):\n",
        "    indices = np.indices(self.grid.shape)\n",
        "    index_pairs = zip(indices[0].flatten(), indices[1].flatten())\n",
        "    return iter(index_pairs)\n",
        "\n",
        "\n",
        "  def get_state(self,\n",
        "                position: Tuple[int, int]) -> int:\n",
        "    return self.grid[position[0]][position[1]]\n",
        "\n",
        "\n",
        "  def get_neighbours(self,\n",
        "                     position: Tuple[int, int]) -> List[Tuple[int, int]]:\n",
        "    xmax, ymax = self.grid.shape\n",
        "    x, y = position\n",
        "    neighbours_x_pos = [max(0, x - 1), x, min(x + 1, xmax - 1)]\n",
        "    neighbours_y_pos = [max(0, y - 1), y, min(y + 1, ymax - 1)]\n",
        "    neighbours = product(neighbours_x_pos, neighbours_y_pos)\n",
        "    neighbours = set(neighbours)\n",
        "    neighbours = [neighbour for neighbour in neighbours if np.any(neighbour != position)]\n",
        "    return neighbours\n",
        "\n",
        "\n",
        "  def num_live_neighbours(self,\n",
        "                          position: Tuple[int, int]) -> int:\n",
        "    neighbours = self.get_neighbours(position)\n",
        "    neighbour_states = np.array([self.get_state(neighbour) for neighbour in neighbours])\n",
        "    return np.sum(neighbour_states == CellState.ALIVE.value)\n",
        "\n",
        "\n",
        "  def update(self, num_steps, inplace=False) -> List['GameOfLifeGrid']:\n",
        "    curr_grid = self\n",
        "    grids = [curr_grid]\n",
        "    for step in range(num_steps):\n",
        "      curr_grid = curr_grid.update_one_step(inplace=inplace)\n",
        "      grids.append(curr_grid)\n",
        "    return grids\n",
        "\n",
        "\n",
        "  def update_one_step(self, inplace=False) -> 'GameOfLifeGrid':     \n",
        "    # Make a copy of the current grid.\n",
        "    new_grid = self.grid.copy()\n",
        "\n",
        "    # Get alive and dead cell positions.\n",
        "    live_positions = np.argwhere(self.grid == CellState.ALIVE)\n",
        "    dead_positions = np.argwhere(self.grid == CellState.DEAD)\n",
        "\n",
        "    # Apply Game of Life rules here.\n",
        "    # Live cells can die.\n",
        "    for live_position in live_positions:\n",
        "      if self.num_live_neighbours(live_position) not in [2, 3]:\n",
        "        GameOfLifeGrid._set_as_dead(new_grid, live_position)\n",
        "\n",
        "    # Dead cells can be resurrected.\n",
        "    for dead_position in dead_positions:\n",
        "      if self.num_live_neighbours(dead_position) == 3:\n",
        "        GameOfLifeGrid._set_as_alive(new_grid, dead_position)\n",
        "    \n",
        "    # Update inplace, or make a new grid?\n",
        "    if inplace:\n",
        "      self.grid = new_grid\n",
        "      return self\n",
        "\n",
        "    else:\n",
        "      new_grid_instance = type(self)(grid_dims=self.grid.shape)\n",
        "      new_grid_instance.grid = new_grid\n",
        "      return new_grid_instance\n",
        "\n",
        "  # Returns an adjacency matrix and node features from the current grid.\n",
        "  def get_features(self):\n",
        "    # Map cell positions to indices.\n",
        "    index_map = {pos: index for index, pos in enumerate(self)}\n",
        "    \n",
        "    # Compute adjacency matrix.\n",
        "    num_nodes = self.grid.size\n",
        "    adjacency_matrix = np.zeros((num_nodes, num_nodes))\n",
        "    for pos, index in index_map.items():\n",
        "      neighbours_pos = self.get_neighbours(pos)\n",
        "      neighbours_indexes = [index_map[neigh_pos] for neigh_pos in neighbours_pos] \n",
        "      adjacency_matrix[index][neighbours_indexes] = 1\n",
        "  \n",
        "    # List node features.\n",
        "    node_features = np.array([self.get_state(pos) for pos in index_map]).reshape((num_nodes, 1)).astype('double')\n",
        "\n",
        "    return adjacency_matrix, node_features\n",
        "\n",
        "  def plot(self, fig, ax, show_ticks=False):\n",
        "    ax.imshow(self.grid.T, cmap='Greys', origin='lower')\n",
        "\n",
        "    if show_ticks:\n",
        "      xticks = np.arange(self.grid.shape[0])\n",
        "      yticks = np.arange(self.grid.shape[1])\n",
        "    else:\n",
        "      xticks = []\n",
        "      yticks = []\n",
        "\n",
        "    ax.set_xticks(xticks)\n",
        "    ax.set_yticks(yticks)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "u8wNa_BzDdcZ"
      },
      "source": [
        "# Generates input and output node features, from a sequence of grids.\n",
        "def generate_training_data():\n",
        "\n",
        "  # A fixed size grid.\n",
        "  grid = GameOfLifeGrid(grid_dims=grid_dims)\n",
        "  adjacency_matrix = grid.get_features()[0]\n",
        "\n",
        "  all_inp_features = np.zeros((num_grids, grid_size, 1), dtype=np.float32)\n",
        "  all_out_features = np.zeros((num_grids, grid_size, 1), dtype=np.float32)\n",
        "  all_train_masks = np.zeros((num_grids, grid_size, 1), dtype=np.int32)\n",
        "  all_val_masks = np.zeros((num_grids, grid_size, 1), dtype=np.int32)\n",
        "  all_test_masks = np.zeros((num_grids, grid_size, 1), dtype=np.int32)\n",
        "\n",
        "  index = 0\n",
        "  for _, num_alive_cells in enumerate(range(1, grid_size//2 + 1)):\n",
        "    for _ in range(num_grids_per_num_alive_cells):\n",
        "  \n",
        "      # Choose some cells to keep alive, and update the grid.\n",
        "      grid.initialize_randomly(num_alive_cells)\n",
        "      final_grid = grid.update(num_update_steps)[-1]\n",
        "\n",
        "      # Save node features.\n",
        "      _, inp_features = grid.get_features()\n",
        "      _, out_features = final_grid.get_features()\n",
        "\n",
        "      # Train and test masks.\n",
        "      mask = np.zeros(grid_size, dtype=np.int32)\n",
        "      mask[:grid_size//4] = 1\n",
        "      mask[grid_size//4:grid_size//2] = 2\n",
        "      np.random.shuffle(mask)\n",
        "\n",
        "      train_mask = (mask == 1)\n",
        "      val_mask = (mask == 2)\n",
        "      test_mask = (mask == 0)\n",
        "\n",
        "      train_mask = np.expand_dims(train_mask, axis=1)\n",
        "      val_mask = np.expand_dims(val_mask, axis=1)\n",
        "      test_mask = np.expand_dims(test_mask, axis=1)\n",
        "\n",
        "      # Save.\n",
        "      all_inp_features[index] = inp_features\n",
        "      all_out_features[index] = out_features\n",
        "      all_train_masks[index] = train_mask\n",
        "      all_val_masks[index] = val_mask\n",
        "      all_test_masks[index] = test_mask\n",
        "      index += 1\n",
        "\n",
        "  return adjacency_matrix, all_inp_features, all_out_features, all_train_masks, all_val_masks, all_test_masks"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-bAUqRBiDP11"
      },
      "source": [
        "# Seed PRNG.\n",
        "np.random.seed(data_generation_seed)\n",
        "tf.random.set_seed(data_generation_seed)\n",
        "\n",
        "adjacency_matrix, all_inp_features, all_out_features, all_train_masks, all_val_masks, all_test_masks = generate_training_data()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "V3W5z_vRzzbX"
      },
      "source": [
        "# Convert to tf.Datasets, to use batching.\n",
        "dataset = tf.data.Dataset.from_tensor_slices((all_inp_features, all_out_features, all_train_masks, all_val_masks))\n",
        "dataset = dataset.repeat(num_epochs).shuffle(num_epochs * num_grids, reshuffle_each_iteration=True).batch(batch_size)\n",
        "\n",
        "# Index into LOG_DIR using the current date and time.\n",
        "current_time = dt.datetime.now().strftime(\"%Y-%m-%d/%H:%M:%S\")\n",
        "train_log_dir = LOG_DIR + '/' + current_time + '/train'\n",
        "val_log_dir = LOG_DIR + '/' + current_time + '/val'\n",
        "train_saved_dir = SAVED_DIR + '/' + model_type\n",
        "\n",
        "# Create separate writers for the training and test sets.\n",
        "train_summary_writer = tf.summary.create_file_writer(train_log_dir)\n",
        "val_summary_writer = tf.summary.create_file_writer(val_log_dir)\n",
        "\n",
        "# Define our metrics.\n",
        "train_loss_metr = tf.keras.metrics.Mean('train_loss', dtype=tf.float64)\n",
        "val_loss_metr = tf.keras.metrics.Mean('val_loss', dtype=tf.float64)\n",
        "train_accuracy_metr = tf.keras.metrics.CategoricalAccuracy('train_accuracy')\n",
        "val_accuracy_metr = tf.keras.metrics.CategoricalAccuracy('val_accuracy')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e_qcnQ_vDXTD"
      },
      "source": [
        "%tensorboard --logdir {LOG_DIR}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "epycnXSFTG_V"
      },
      "source": [
        "# Define the model.\n",
        "def build_model(model_type, m):\n",
        "  if model_type == 'gcn-degree-i':\n",
        "    return snt.Sequential([\n",
        "        GCNLayer(input_dims=1, output_dims=2*m, normalization='degree-i', name='gcn_layer_1'),\n",
        "        tf.nn.relu,\n",
        "        snt.Linear(output_size=m, name='point_update'),\n",
        "        tf.nn.relu,\n",
        "        snt.Linear(output_size=2, name='predict'),\n",
        "    ])\n",
        "  elif model_type == 'gcn-no-norm':\n",
        "    return snt.Sequential([\n",
        "        GCNLayer(input_dims=1, output_dims=2*m, normalization='none', name='gcn_layer_1'),\n",
        "        tf.nn.relu,\n",
        "        snt.Linear(output_size=m, name='point_update'),\n",
        "        tf.nn.relu,\n",
        "        snt.Linear(output_size=2, name='predict'),\n",
        "    ])\n",
        "  elif model_type == 'cnn':\n",
        "    return snt.Sequential([\n",
        "        snt.Conv2D(kernel_shape=3, output_channels=2*m, name='cnn_layer_1'),\n",
        "        tf.nn.relu,\n",
        "        snt.Conv2D(kernel_shape=1, output_channels=m, name='point_update'),\n",
        "        tf.nn.relu,\n",
        "        snt.Linear(output_size=2, name='predict'),\n",
        "    ])\n",
        "  else:\n",
        "    raise ValueError('Invalid model_type.')\n",
        "\n",
        "# Seed before building the model.\n",
        "np.random.seed(model_seed)\n",
        "tf.random.set_seed(model_seed)\n",
        "\n",
        "# Construct model.\n",
        "print('Chosen model: %s' % model_type)\n",
        "print('Saved model will be at %s' % train_saved_dir)\n",
        "model = build_model(model_type, m)\n",
        "\n",
        "# Reshape for the CNNs.\n",
        "if model_type == 'cnn':\n",
        "  inp_shape = (8, 8, 1)\n",
        "else:\n",
        "  inp_shape = (64, 1)\n",
        "\n",
        "# Call model before evaluating its size.\n",
        "model(all_inp_features[0].reshape(-1, *inp_shape))\n",
        "model_size = np.sum([np.prod(variable.shape) for variable in model.variables])\n",
        "print('Model Size: %d parameters.' % model_size)\n",
        "  \n",
        "# Optimizer.\n",
        "learning_rate = 5e-2\n",
        "opt = snt.optimizers.SGD(learning_rate)\n",
        "\n",
        "\n",
        "# Save model.\n",
        "def save_with_prefix(prefix):\n",
        "\n",
        "  @tf.function(input_signature=[tf.TensorSpec([None, *inp_shape])])\n",
        "  def predict(x):\n",
        "    return tf.argmax(model(x), axis=-1)\n",
        "\n",
        "  to_save = snt.Module()\n",
        "  to_save.predict = predict\n",
        "  to_save.all_variables = list(model.variables)\n",
        "  tf.saved_model.save(to_save, train_saved_dir + '/' + str(prefix))\n",
        "\n",
        "\n",
        "# Update after one batch of training data.\n",
        "def train_step(step, inp_features, out_features, train_mask, val_mask):\n",
        "\n",
        "  # Reshape.\n",
        "  train_mask = tf.reshape(train_mask, -1)\n",
        "  val_mask = tf.reshape(val_mask, -1)\n",
        "  out_features = tf.reshape(out_features, -1)\n",
        "  inp_features = tf.reshape(inp_features, (-1, *inp_shape))\n",
        "\n",
        "  # Apply masks.\n",
        "  train_indices = tf.reshape(tf.where(train_mask), -1)\n",
        "  val_indices = tf.reshape(tf.where(val_mask), -1)\n",
        "  train_labels = tf.cast(tf.gather(out_features, train_indices), tf.int32)\n",
        "  val_labels = tf.cast(tf.gather(out_features, val_indices), tf.int32)\n",
        "\n",
        "  # Watch gradients for these computations!\n",
        "  with tf.GradientTape() as tape:\n",
        "\n",
        "    # Get predictions from model.\n",
        "    logits = model(inp_features)\n",
        "    logits = tf.reshape(logits, (-1, 2))\n",
        "\n",
        "    # Compute training loss.\n",
        "    train_logits = tf.gather(logits, train_indices)\n",
        "    train_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_logits, labels=train_labels)\n",
        "    train_loss = tf.math.reduce_mean(train_loss)\n",
        "    train_loss_metr(train_loss)\n",
        "\n",
        "  # Update parameters.\n",
        "  params = model.trainable_variables\n",
        "  grads = tape.gradient(train_loss, params)\n",
        "  opt.apply(grads, params)\n",
        "\n",
        "  # Compute validation loss.\n",
        "  val_logits = tf.gather(logits, val_indices)\n",
        "  val_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=val_logits, labels=val_labels)\n",
        "  val_loss = tf.math.reduce_mean(val_loss)\n",
        "  val_loss_metr(val_loss)\n",
        "\n",
        "  # Compute training accuracy.\n",
        "  train_labels_one_hot = tf.one_hot(train_labels, depth=2)\n",
        "  train_preds_one_hot = tf.one_hot(tf.math.argmax(train_logits, axis=1), depth=2)\n",
        "  train_accuracy_metr(train_labels_one_hot, train_preds_one_hot)\n",
        "\n",
        "  # Compute validation accuracy.\n",
        "  val_labels_one_hot = tf.one_hot(val_labels, depth=2)\n",
        "  val_preds_one_hot = tf.one_hot(tf.math.argmax(val_logits, axis=1), depth=2)\n",
        "  val_accuracy_metr(val_labels_one_hot, val_preds_one_hot)\n",
        "  \n",
        "  # Write to logs.\n",
        "  with train_summary_writer.as_default():\n",
        "    tf.summary.scalar('loss', train_loss_metr.result(), step=step)\n",
        "  with val_summary_writer.as_default():\n",
        "    tf.summary.scalar('loss', val_loss_metr.result(), step=step)\n",
        "\n",
        "  with train_summary_writer.as_default():\n",
        "    tf.summary.scalar('accuracy', train_accuracy_metr.result(), step=step)\n",
        "  with val_summary_writer.as_default():\n",
        "    tf.summary.scalar('accuracy', val_accuracy_metr.result(), step=step)\n",
        "\n",
        "  # Reset metrics' states, otherwise we would be averaging over them.\n",
        "  train_loss_metr.reset_states()\n",
        "  val_loss_metr.reset_states()\n",
        "  train_accuracy_metr.reset_states()\n",
        "  val_accuracy_metr.reset_states()\n",
        "\n",
        "  return train_loss, val_loss\n",
        "\n",
        "\n",
        "val_losses = []\n",
        "for step, (inp_features, out_features, train_mask, val_mask) in enumerate(dataset):\n",
        "  train_loss, val_loss = train_step(step, inp_features, out_features, train_mask, val_mask)\n",
        "  val_losses.append(val_loss.numpy())\n",
        "\n",
        "  if step % 1000 == 0:\n",
        "    print('Step %d completed.' % step)\n",
        "\n",
        "  if step > early_stopping_lim and np.min(val_losses[-early_stopping_lim:]) > val_losses[-early_stopping_lim - 1]:\n",
        "    print(val_losses, np.min(val_losses[-early_stopping_lim:]), val_losses[-early_stopping_lim - 1])\n",
        "    print('Early-stopping at step %d.' % step)\n",
        "    break\n",
        "\n",
        "save_with_prefix(model_size) "
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}